{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "churn_data = pd.read_csv('https://raw.githubusercontent.com/zekelabs/data-science-complete-tutorial/master/Data/churn.csv.txt',\n",
    "                         parse_dates=['last_trip_date','signup_date'] )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 50000 entries, 0 to 49999\n",
      "Data columns (total 12 columns):\n",
      "avg_dist                  50000 non-null float64\n",
      "avg_rating_by_driver      49799 non-null float64\n",
      "avg_rating_of_driver      41878 non-null float64\n",
      "avg_surge                 50000 non-null float64\n",
      "city                      50000 non-null object\n",
      "last_trip_date            50000 non-null datetime64[ns]\n",
      "phone                     49604 non-null object\n",
      "signup_date               50000 non-null datetime64[ns]\n",
      "surge_pct                 50000 non-null float64\n",
      "trips_in_first_30_days    50000 non-null int64\n",
      "luxury_car_user           50000 non-null bool\n",
      "weekday_pct               50000 non-null float64\n",
      "dtypes: bool(1), datetime64[ns](2), float64(6), int64(1), object(2)\n",
      "memory usage: 4.2+ MB\n"
     ]
    }
   ],
   "source": [
    "churn_data.info()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>avg_dist</th>\n",
       "      <th>avg_rating_by_driver</th>\n",
       "      <th>avg_rating_of_driver</th>\n",
       "      <th>avg_surge</th>\n",
       "      <th>city</th>\n",
       "      <th>last_trip_date</th>\n",
       "      <th>phone</th>\n",
       "      <th>signup_date</th>\n",
       "      <th>surge_pct</th>\n",
       "      <th>trips_in_first_30_days</th>\n",
       "      <th>luxury_car_user</th>\n",
       "      <th>weekday_pct</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>3.67</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.7</td>\n",
       "      <td>1.10</td>\n",
       "      <td>King's Landing</td>\n",
       "      <td>2014-06-17</td>\n",
       "      <td>iPhone</td>\n",
       "      <td>2014-01-25</td>\n",
       "      <td>15.4</td>\n",
       "      <td>4</td>\n",
       "      <td>True</td>\n",
       "      <td>46.2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>8.26</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>1.00</td>\n",
       "      <td>Astapor</td>\n",
       "      <td>2014-05-05</td>\n",
       "      <td>Android</td>\n",
       "      <td>2014-01-29</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0</td>\n",
       "      <td>False</td>\n",
       "      <td>50.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.77</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.3</td>\n",
       "      <td>1.00</td>\n",
       "      <td>Astapor</td>\n",
       "      <td>2014-01-07</td>\n",
       "      <td>iPhone</td>\n",
       "      <td>2014-01-06</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3</td>\n",
       "      <td>False</td>\n",
       "      <td>100.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2.36</td>\n",
       "      <td>4.9</td>\n",
       "      <td>4.6</td>\n",
       "      <td>1.14</td>\n",
       "      <td>King's Landing</td>\n",
       "      <td>2014-06-29</td>\n",
       "      <td>iPhone</td>\n",
       "      <td>2014-01-10</td>\n",
       "      <td>20.0</td>\n",
       "      <td>9</td>\n",
       "      <td>True</td>\n",
       "      <td>80.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>3.13</td>\n",
       "      <td>4.9</td>\n",
       "      <td>4.4</td>\n",
       "      <td>1.19</td>\n",
       "      <td>Winterfell</td>\n",
       "      <td>2014-03-15</td>\n",
       "      <td>Android</td>\n",
       "      <td>2014-01-27</td>\n",
       "      <td>11.8</td>\n",
       "      <td>14</td>\n",
       "      <td>False</td>\n",
       "      <td>82.4</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   avg_dist  avg_rating_by_driver  avg_rating_of_driver  avg_surge  \\\n",
       "0      3.67                   5.0                   4.7       1.10   \n",
       "1      8.26                   5.0                   5.0       1.00   \n",
       "2      0.77                   5.0                   4.3       1.00   \n",
       "3      2.36                   4.9                   4.6       1.14   \n",
       "4      3.13                   4.9                   4.4       1.19   \n",
       "\n",
       "             city last_trip_date    phone signup_date  surge_pct  \\\n",
       "0  King's Landing     2014-06-17   iPhone  2014-01-25       15.4   \n",
       "1         Astapor     2014-05-05  Android  2014-01-29        0.0   \n",
       "2         Astapor     2014-01-07   iPhone  2014-01-06        0.0   \n",
       "3  King's Landing     2014-06-29   iPhone  2014-01-10       20.0   \n",
       "4      Winterfell     2014-03-15  Android  2014-01-27       11.8   \n",
       "\n",
       "   trips_in_first_30_days  luxury_car_user  weekday_pct  \n",
       "0                       4             True         46.2  \n",
       "1                       0            False         50.0  \n",
       "2                       3            False        100.0  \n",
       "3                       9             True         80.0  \n",
       "4                      14            False         82.4  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "churn_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Timestamp('2014-07-01 00:00:00')"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "churn_data.last_trip_date.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Timestamp('2014-01-01 00:00:00')"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "churn_data.last_trip_date.min()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datetime\n",
    "cutoff = churn_data.last_trip_date.max() - datetime.timedelta(30,0,0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Timestamp('2014-06-01 00:00:00')"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cutoff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "churn_data['churn'] = churn_data.last_trip_date.map(lambda d: 1 if d > cutoff else 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>avg_dist</th>\n",
       "      <th>avg_rating_by_driver</th>\n",
       "      <th>avg_rating_of_driver</th>\n",
       "      <th>avg_surge</th>\n",
       "      <th>city</th>\n",
       "      <th>last_trip_date</th>\n",
       "      <th>phone</th>\n",
       "      <th>signup_date</th>\n",
       "      <th>surge_pct</th>\n",
       "      <th>trips_in_first_30_days</th>\n",
       "      <th>luxury_car_user</th>\n",
       "      <th>weekday_pct</th>\n",
       "      <th>churn</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>3.67</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.7</td>\n",
       "      <td>1.10</td>\n",
       "      <td>King's Landing</td>\n",
       "      <td>2014-06-17</td>\n",
       "      <td>iPhone</td>\n",
       "      <td>2014-01-25</td>\n",
       "      <td>15.4</td>\n",
       "      <td>4</td>\n",
       "      <td>True</td>\n",
       "      <td>46.2</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>8.26</td>\n",
       "      <td>5.0</td>\n",
       "      <td>5.0</td>\n",
       "      <td>1.00</td>\n",
       "      <td>Astapor</td>\n",
       "      <td>2014-05-05</td>\n",
       "      <td>Android</td>\n",
       "      <td>2014-01-29</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0</td>\n",
       "      <td>False</td>\n",
       "      <td>50.0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.77</td>\n",
       "      <td>5.0</td>\n",
       "      <td>4.3</td>\n",
       "      <td>1.00</td>\n",
       "      <td>Astapor</td>\n",
       "      <td>2014-01-07</td>\n",
       "      <td>iPhone</td>\n",
       "      <td>2014-01-06</td>\n",
       "      <td>0.0</td>\n",
       "      <td>3</td>\n",
       "      <td>False</td>\n",
       "      <td>100.0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2.36</td>\n",
       "      <td>4.9</td>\n",
       "      <td>4.6</td>\n",
       "      <td>1.14</td>\n",
       "      <td>King's Landing</td>\n",
       "      <td>2014-06-29</td>\n",
       "      <td>iPhone</td>\n",
       "      <td>2014-01-10</td>\n",
       "      <td>20.0</td>\n",
       "      <td>9</td>\n",
       "      <td>True</td>\n",
       "      <td>80.0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>3.13</td>\n",
       "      <td>4.9</td>\n",
       "      <td>4.4</td>\n",
       "      <td>1.19</td>\n",
       "      <td>Winterfell</td>\n",
       "      <td>2014-03-15</td>\n",
       "      <td>Android</td>\n",
       "      <td>2014-01-27</td>\n",
       "      <td>11.8</td>\n",
       "      <td>14</td>\n",
       "      <td>False</td>\n",
       "      <td>82.4</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   avg_dist  avg_rating_by_driver  avg_rating_of_driver  avg_surge  \\\n",
       "0      3.67                   5.0                   4.7       1.10   \n",
       "1      8.26                   5.0                   5.0       1.00   \n",
       "2      0.77                   5.0                   4.3       1.00   \n",
       "3      2.36                   4.9                   4.6       1.14   \n",
       "4      3.13                   4.9                   4.4       1.19   \n",
       "\n",
       "             city last_trip_date    phone signup_date  surge_pct  \\\n",
       "0  King's Landing     2014-06-17   iPhone  2014-01-25       15.4   \n",
       "1         Astapor     2014-05-05  Android  2014-01-29        0.0   \n",
       "2         Astapor     2014-01-07   iPhone  2014-01-06        0.0   \n",
       "3  King's Landing     2014-06-29   iPhone  2014-01-10       20.0   \n",
       "4      Winterfell     2014-03-15  Android  2014-01-27       11.8   \n",
       "\n",
       "   trips_in_first_30_days  luxury_car_user  weekday_pct  churn  \n",
       "0                       4             True         46.2      1  \n",
       "1                       0            False         50.0      0  \n",
       "2                       3            False        100.0      0  \n",
       "3                       9             True         80.0      1  \n",
       "4                      14            False         82.4      0  "
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "churn_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "cat_cols = list(churn_data.select_dtypes('object').columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['city', 'phone']"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cat_cols"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_cols = list(churn_data.select_dtypes('float64').columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_cols.append('trips_in_first_30_days')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['avg_dist',\n",
       " 'avg_rating_by_driver',\n",
       " 'avg_rating_of_driver',\n",
       " 'avg_surge',\n",
       " 'surge_pct',\n",
       " 'weekday_pct',\n",
       " 'trips_in_first_30_days']"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "num_cols"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Pipeline Creation for categorical columns & numerical columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.pipeline import make_pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.impute import SimpleImputer\n",
    "from sklearn.preprocessing import OneHotEncoder, MinMaxScaler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "cat_pipeline = make_pipeline(SimpleImputer(strategy='constant', fill_value='missing'),\n",
    "                            OneHotEncoder())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "cat_data_tf = cat_pipeline.fit_transform(churn_data[cat_cols]).toarray()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_pipeline = make_pipeline(SimpleImputer(strategy='median'),\n",
    "                             MinMaxScaler())                            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_data_tf = num_pipeline.fit_transform(churn_data[num_cols])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\awant\\Anaconda3\\lib\\site-packages\\ipykernel_launcher.py:1: FutureWarning: reshape is deprecated and will raise in a subsequent release. Please use .values.reshape(...) instead\n",
      "  \"\"\"Entry point for launching an IPython kernel.\n"
     ]
    }
   ],
   "source": [
    "feature_data = np.hstack([cat_data_tf, num_data_tf, churn_data.luxury_car_user.reshape(-1,1)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(50000, 14)"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "feature_data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "target_data = churn_data.churn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.feature_selection import SelectKBest\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "final_pipeline = make_pipeline(SelectKBest(k=8), RandomForestClassifier(n_estimators=20))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainX, testX, trainY, testY = train_test_split(feature_data,target_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Pipeline(memory=None,\n",
       "     steps=[('selectkbest', SelectKBest(k=8, score_func=<function f_classif at 0x000001F05973DB70>)), ('randomforestclassifier', RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',\n",
       "            max_depth=None, max_features='auto', max_leaf_nodes=None,\n",
       "            min_impurity_decre...obs=None,\n",
       "            oob_score=False, random_state=None, verbose=0,\n",
       "            warm_start=False))])"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "final_pipeline.fit(trainX,trainY)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.66072"
      ]
     },
     "execution_count": 48,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "final_pipeline.score(testX,testY)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import GridSearchCV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "params = {\n",
    "    'selectkbest__k':[5,8,12],\n",
    "    'randomforestclassifier__n_estimators':[10,30,50]\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "grid_model = GridSearchCV(final_pipeline,param_grid=params, cv=5, n_jobs=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GridSearchCV(cv=5, error_score='raise-deprecating',\n",
       "       estimator=Pipeline(memory=None,\n",
       "     steps=[('selectkbest', SelectKBest(k=8, score_func=<function f_classif at 0x000001F05973DB70>)), ('randomforestclassifier', RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',\n",
       "            max_depth=None, max_features='auto', max_leaf_nodes=None,\n",
       "            min_impurity_decre...obs=None,\n",
       "            oob_score=False, random_state=None, verbose=0,\n",
       "            warm_start=False))]),\n",
       "       fit_params=None, iid='warn', n_jobs=-1,\n",
       "       param_grid={'selectkbest__k': [5, 8, 12], 'randomforestclassifier__n_estimators': [10, 30, 50]},\n",
       "       pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',\n",
       "       scoring=None, verbose=0)"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grid_model.fit(trainX,trainY)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'randomforestclassifier__n_estimators': 50, 'selectkbest__k': 12}"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grid_model.best_params_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7619733333333333"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grid_model.best_score_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7592"
      ]
     },
     "execution_count": 56,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grid_model.score(testX,testY)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.compose import ColumnTransformer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['avg_dist',\n",
       " 'avg_rating_by_driver',\n",
       " 'avg_rating_of_driver',\n",
       " 'avg_surge',\n",
       " 'surge_pct',\n",
       " 'weekday_pct',\n",
       " 'trips_in_first_30_days']"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "num_cols"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Using column transformer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [],
   "source": [
    "preprocessor = ColumnTransformer(\n",
    "    transformers=[\n",
    "        ('num', num_pipeline, num_cols),\n",
    "        ('cat', cat_pipeline, cat_cols)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.pipeline import Pipeline\n",
    "pipeline = Pipeline(steps=[('preprocessor',preprocessor),\n",
    "                           ('selectkbest', SelectKBest(k=8)),\n",
    "                           ('classifier',RandomForestClassifier(n_estimators=20))])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {},
   "outputs": [],
   "source": [
    "#pipeline = make_pipeline(preprocessor, SelectKBest(k=8), RandomForestClassifier(n_estimators=20))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Pipeline(memory=None,\n",
       "     steps=[('simpleimputer', SimpleImputer(copy=True, fill_value=None, missing_values=nan,\n",
       "       strategy='median', verbose=0)), ('minmaxscaler', MinMaxScaler(copy=True, feature_range=(0, 1)))])"
      ]
     },
     "execution_count": 98,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "num_pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {},
   "outputs": [],
   "source": [
    "params = {\n",
    "    'preprocessor__num__simpleimputer__strategy':['mean','median'],\n",
    "    'selectkbest__k':[12,13],\n",
    "    'classifier__n_estimators':[50,70]\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [],
   "source": [
    "grid_model = GridSearchCV(pipeline, param_grid=params,cv=5, n_jobs=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_data = churn_data.drop('churn',axis=1)\n",
    "target_data = churn_data.churn\n",
    "trainX, testX, trainY, testY = train_test_split(feature_data,target_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "GridSearchCV(cv=5, error_score='raise-deprecating',\n",
       "       estimator=Pipeline(memory=None,\n",
       "     steps=[('preprocessor', ColumnTransformer(n_jobs=None, remainder='drop', sparse_threshold=0.3,\n",
       "         transformer_weights=None,\n",
       "         transformers=[('num', Pipeline(memory=None,\n",
       "     steps=[('simpleimputer', SimpleImputer(copy=True, fill_value=None, missing_values=nan,\n",
       "       strategy='median',...obs=None,\n",
       "            oob_score=False, random_state=None, verbose=0,\n",
       "            warm_start=False))]),\n",
       "       fit_params=None, iid='warn', n_jobs=-1,\n",
       "       param_grid={'preprocessor__num__simpleimputer__strategy': ['mean', 'median'], 'selectkbest__k': [12, 13], 'classifier__n_estimators': [50, 70]},\n",
       "       pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',\n",
       "       scoring=None, verbose=0)"
      ]
     },
     "execution_count": 102,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grid_model.fit(trainX,trainY)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7528"
      ]
     },
     "execution_count": 103,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grid_model.best_score_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'classifier__n_estimators': 70,\n",
       " 'preprocessor__num__simpleimputer__strategy': 'mean',\n",
       " 'selectkbest__k': 12}"
      ]
     },
     "execution_count": 104,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "grid_model.best_params_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
